{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 机器翻译与数据集\n",
    ":label:`sec_machine_translation`\n",
    "\n",
    "语言模型是自然语言处理的关键，\n",
    "而*机器翻译*是语言模型最成功的基准测试。\n",
    "因为机器翻译正是将输入序列转换成输出序列的\n",
    "*序列转换模型*（sequence transduction）的核心问题。\n",
    "序列转换模型在各类现代人工智能应用中发挥着至关重要的作用，\n",
    "因此我们将其做为本章剩余部分和 :numref:`chap_attention`的重点。\n",
    "为此，本节将介绍机器翻译问题及其后文需要使用的数据集。\n",
    "\n",
    "*机器翻译*（machine translation）指的是\n",
    "将序列从一种语言自动翻译成另一种语言。\n",
    "事实上，这个研究领域可以追溯到数字计算机发明后不久的20世纪40年代，\n",
    "特别是在第二次世界大战中使用计算机破解语言编码。\n",
    "几十年来，在使用神经网络进行端到端学习的兴起之前，\n",
    "统计学方法在这一领域一直占据主导地位\n",
    " :cite:`Brown.Cocke.Della-Pietra.ea.1988,Brown.Cocke.Della-Pietra.ea.1990`。\n",
    "因为*统计机器翻译*（statisticalmachine translation）涉及了\n",
    "翻译模型和语言模型等组成部分的统计分析，\n",
    "因此基于神经网络的方法通常被称为\n",
    "*神经机器翻译*（neuralmachine translation），\n",
    "用于将两种翻译模型区分开来。\n",
    "\n",
    "本书的关注点是神经网络机器翻译方法，强调的是端到端的学习。\n",
    "与 :numref:`sec_language_model`中的语料库\n",
    "是单一语言的语言模型问题存在不同，\n",
    "机器翻译的数据集是由源语言和目标语言的文本序列对组成的。\n",
    "因此，我们需要一种完全不同的方法来预处理机器翻译数据集，\n",
    "而不是复用语言模型的预处理程序。\n",
    "下面，我们看一下如何将预处理后的数据加载到小批量中用于训练。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModel.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java\n",
    "%load ../utils/timemachine/TimeMachineDataset.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.charset.*;\n",
    "import java.util.zip.*;\n",
    "import java.util.stream.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 下载和预处理数据集\n",
    "\n",
    "首先，下载一个由[Tatoeba项目的双语句子对](http://www.manythings.org/anki/)\n",
    "组成的“英－法”数据集，数据集中的每一行都是制表符分隔的文本序列对，\n",
    "序列对由英文文本序列和翻译后的法语文本序列组成。\n",
    "请注意，每个文本序列可以是一个句子，\n",
    "也可以是包含多个句子的一个段落。\n",
    "在这个将英语翻译成法语的机器翻译问题中，\n",
    "英语是*源语言*（source language），\n",
    "法语是*目标语言*（target language）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String readDataNMT() throws IOException {\n",
    "    DownloadUtils.download(\n",
    "            \"http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip\", \"fra-eng.zip\");\n",
    "    ZipFile zipFile = new ZipFile(new File(\"fra-eng.zip\"));\n",
    "    Enumeration<? extends ZipEntry> entries = zipFile.entries();\n",
    "    while (entries.hasMoreElements()) {\n",
    "        ZipEntry entry = entries.nextElement();\n",
    "        if (entry.getName().contains(\"fra.txt\")) {\n",
    "            InputStream stream = zipFile.getInputStream(entry);\n",
    "            return new String(stream.readAllBytes(), StandardCharsets.UTF_8);\n",
    "        }\n",
    "    }\n",
    "    return null;\n",
    "}\n",
    "\n",
    "String rawText = readDataNMT();\n",
    "System.out.println(rawText.substring(0, 75));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下载数据集后，原始文本数据需要经过几个预处理步骤。\n",
    "例如，我们用空格代替*不间断空格*（non-breaking space），\n",
    "使用小写字母替换大写字母，并在单词和标点符号之间插入空格。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String preprocessNMT(String text) {\n",
    "    // 使用空格替换不间断空格\n",
    "    // 使用小写字母替换大写字母\n",
    "    text = text.replace('\\u202f', ' ').replaceAll(\"\\\\xa0\", \" \").toLowerCase();\n",
    "\n",
    "    // 在单词和标点符号之间插入空格\n",
    "    StringBuilder out = new StringBuilder();\n",
    "    Character currChar;\n",
    "    for (int i = 0; i < text.length(); i++) {\n",
    "        currChar = text.charAt(i);\n",
    "        if (i > 0 && noSpace(currChar, text.charAt(i - 1))) {\n",
    "            out.append(' ');\n",
    "        }\n",
    "        out.append(currChar);\n",
    "    }\n",
    "    return out.toString();\n",
    "}\n",
    "\n",
    "public static boolean noSpace(Character currChar, Character prevChar) {\n",
    "    /* Preprocess the English-French dataset. */\n",
    "    return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar)\n",
    "            && prevChar != ' ';\n",
    "}\n",
    "\n",
    "String text = preprocessNMT(rawText);\n",
    "System.out.println(text.substring(0, 80));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 词元化\n",
    "\n",
    "与 :numref:`sec_language_model`中的字符级词元化不同，\n",
    "在机器翻译中，我们更喜欢单词级词元化\n",
    "（最先进的模型可能使用更高级的词元化技术）。\n",
    "下面的`tokenize_nmt`函数对前`num_examples`个文本序列对进行词元，\n",
    "其中每个词元要么是一个词，要么是一个标点符号。\n",
    "此函数返回两个词元列表：`source`和`target`：\n",
    "`source[i]`是源语言（这里是英语）第$i$个文本序列的词元列表，\n",
    "`target[i]`是目标语言（这里是法语）第$i$个文本序列的词元列表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayList<String[]>, ArrayList<String[]>> tokenizeNMT(\n",
    "        String text, Integer numExamples) {\n",
    "    ArrayList<String[]> source = new ArrayList<>();\n",
    "    ArrayList<String[]> target = new ArrayList<>();\n",
    "\n",
    "    int i = 0;\n",
    "    for (String line : text.split(\"\\n\")) {\n",
    "        if (numExamples != null && i > numExamples) {\n",
    "            break;\n",
    "        }\n",
    "        String[] parts = line.split(\"\\t\");\n",
    "        if (parts.length == 2) {\n",
    "            source.add(parts[0].split(\" \"));\n",
    "            target.add(parts[1].split(\" \"));\n",
    "        }\n",
    "        i += 1;\n",
    "    }\n",
    "    return new Pair<>(source, target);\n",
    "}\n",
    "\n",
    "Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text.toString(), null);\n",
    "ArrayList<String[]> source = pair.getKey();\n",
    "ArrayList<String[]> target = pair.getValue();\n",
    "for (String[] subArr : source.subList(0, 6)) {\n",
    "    System.out.println(Arrays.toString(subArr));\n",
    "}\n",
    "\n",
    "for (String[] subArr : target.subList(0, 6)) {\n",
    "    System.out.println(Arrays.toString(subArr));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们绘制每个文本序列所包含的词元数量的直方图。\n",
    "在这个简单的“英－法”数据集中，大多数文本序列的词元数量少于$20$个。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] y1 = new double[source.size()];\n",
    "for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length;\n",
    "double[] y2 = new double[target.size()];\n",
    "for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length;\n",
    "\n",
    "HistogramTrace trace1 =\n",
    "        HistogramTrace.builder(y1).opacity(.75).name(\"source\").nBinsX(20).build();\n",
    "HistogramTrace trace2 =\n",
    "        HistogramTrace.builder(y2).opacity(.75).name(\"target\").nBinsX(20).build();\n",
    "\n",
    "Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build();\n",
    "new Figure(layout, trace1, trace2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 词表\n",
    "\n",
    "由于机器翻译数据集由语言对组成，\n",
    "因此我们可以分别为源语言和目标语言构建两个词表。\n",
    "使用单词级词元化时，词表大小将明显大于使用字符级词元化时的词表大小。\n",
    "为了缓解这一问题，这里我们将出现次数少于2次的低频率词元\n",
    "视为相同的未知（“&lt;unk&gt;”）词元。\n",
    "除此之外，我们还指定了额外的特定词元，\n",
    "例如在小批量时用于将序列填充到相同长度的填充词元（“&lt;pad&gt;”），\n",
    "以及序列的开始词元（“&lt;bos&gt;”）和结束词元（“&lt;eos&gt;”）。\n",
    "这些特殊词元在自然语言处理任务中比较常用。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Vocab srcVocab = new Vocab(\n",
    "        source.stream().toArray(String[][]::new),\n",
    "        2,\n",
    "        new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "System.out.println(srcVocab.length());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集\n",
    ":label:`subsec_mt_data_loading`\n",
    "\n",
    "回想一下，语言模型中的序列样本都有一个固定的长度，\n",
    "无论这个样本是一个句子的一部分还是跨越了多个句子的一个片断。\n",
    "这个固定长度是由 :numref:`sec_language_model`中的\n",
    "`numSteps`（时间步数或词元数量）参数指定的。\n",
    "在机器翻译中，每个样本都是由源和目标组成的文本序列对，\n",
    "其中的每个文本序列可能具有不同的长度。\n",
    "\n",
    "为了提高计算效率，我们仍然可以通过*截断*（truncation）和\n",
    "*填充*（padding）方式实现一次只处理一个小批量的文本序列。\n",
    "假设同一个小批量中的每个序列都应该具有相同的长度`numSteps`，\n",
    "那么如果文本序列的词元数目少于`numSteps`时，\n",
    "我们将继续在其末尾添加特定的“&lt;pad&gt;”词元，\n",
    "直到其长度达到`numSteps`；\n",
    "反之，我们将截断文本序列时，只取其前`numSteps` 个词元，\n",
    "并且丢弃剩余的词元。这样，每个文本序列将具有相同的长度，\n",
    "以便以相同形状的小批量进行加载。\n",
    "\n",
    "如前所述，下面的`truncatePad`函数将截断或填充文本序列。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) {\n",
    "    // 截断或填充文本序列\n",
    "    int[] line = Arrays.stream(integerLine).mapToInt(i -> i).toArray();\n",
    "    if (line.length > numSteps) {\n",
    "        return Arrays.copyOfRange(line, 0, numSteps);\n",
    "    }\n",
    "    int[] paddingTokenArr = new int[numSteps - line.length]; // Pad\n",
    "    Arrays.fill(paddingTokenArr, paddingToken);\n",
    "\n",
    "    return IntStream.concat(Arrays.stream(line), Arrays.stream(paddingTokenArr)).toArray();\n",
    "}\n",
    "\n",
    "int[] result = truncatePad(srcVocab.getIdxs(source.get(0)), 10, srcVocab.getIdx(\"<pad>\"));\n",
    "System.out.println(Arrays.toString(result));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们定义一个函数，可以将文本序列\n",
    "转换成小批量数据集用于训练。\n",
    "我们将特定的“&lt;eos&gt;”词元添加到所有序列的末尾，\n",
    "用于表示序列的结束。\n",
    "当模型通过一个词元接一个词元地生成序列进行预测时，\n",
    "生成的“&lt;eos&gt;”词元说明完成了序列输出工作。\n",
    "此外，我们还记录了每个文本序列的长度，\n",
    "统计长度时排除了填充词元，\n",
    "在稍后将要介绍的一些模型会需要这个长度信息。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<NDArray, NDArray> buildArrayNMT(\n",
    "        List<String[]> lines, Vocab vocab, int numSteps) {\n",
    "    // 将机器翻译的文本序列转换成小批量\n",
    "    List<Integer[]> linesIntArr = new ArrayList<>();\n",
    "    for (String[] strings : lines) {\n",
    "        linesIntArr.add(vocab.getIdxs(strings));\n",
    "    }\n",
    "    for (int i = 0; i < linesIntArr.size(); i++) {\n",
    "        List<Integer> temp = new ArrayList<>(Arrays.asList(linesIntArr.get(i)));\n",
    "        temp.add(vocab.getIdx(\"<eos>\"));\n",
    "        linesIntArr.set(i, temp.toArray(new Integer[0]));\n",
    "    }\n",
    "\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "    NDArray arr = manager.create(new Shape(linesIntArr.size(), numSteps), DataType.INT32);\n",
    "    int row = 0;\n",
    "    for (Integer[] line : linesIntArr) {\n",
    "        NDArray rowArr = manager.create(truncatePad(line, numSteps, vocab.getIdx(\"<pad>\")));\n",
    "        arr.set(new NDIndex(\"{}:\", row), rowArr);\n",
    "        row += 1;\n",
    "    }\n",
    "    NDArray validLen = arr.neq(vocab.getIdx(\"<pad>\")).sum(new int[] {1});\n",
    "    return new Pair<>(arr, validLen);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "最后，我们定义`loadDataNMT`函数来返回数据迭代器，\n",
    "以及源语言和目标语言的两种词表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayDataset, Pair<Vocab, Vocab>> loadDataNMT(\n",
    "        int batchSize, int numSteps, int numExamples) throws IOException {\n",
    "    // 返回翻译数据集的迭代器和词表\n",
    "    String text = preprocessNMT(readDataNMT());\n",
    "    Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text, numExamples);\n",
    "    ArrayList<String[]> source = pair.getKey();\n",
    "    ArrayList<String[]> target = pair.getValue();\n",
    "    Vocab srcVocab =\n",
    "            new Vocab(\n",
    "                    source.toArray(String[][]::new),\n",
    "                    2,\n",
    "                    new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "    Vocab tgtVocab =\n",
    "            new Vocab(\n",
    "                    target.toArray(String[][]::new),\n",
    "                    2,\n",
    "                    new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "\n",
    "    Pair<NDArray, NDArray> pairArr = buildArrayNMT(source, srcVocab, numSteps);\n",
    "    NDArray srcArr = pairArr.getKey();\n",
    "    NDArray srcValidLen = pairArr.getValue();\n",
    "\n",
    "    pairArr = buildArrayNMT(target, tgtVocab, numSteps);\n",
    "    NDArray tgtArr = pairArr.getKey();\n",
    "    NDArray tgtValidLen = pairArr.getValue();\n",
    "\n",
    "    ArrayDataset dataset =\n",
    "            new ArrayDataset.Builder()\n",
    "                    .setData(srcArr, srcValidLen)\n",
    "                    .optLabels(tgtArr, tgtValidLen)\n",
    "                    .setSampling(batchSize, true)\n",
    "                    .build();\n",
    "\n",
    "    return new Pair<>(dataset, new Pair<>(srcVocab, tgtVocab));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面我们读出“英语－法语”数据集中的第一个小批量数据。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pair<ArrayDataset, Pair<Vocab, Vocab>> output = loadDataNMT(2, 8, 600);\n",
    "ArrayDataset dataset = output.getKey();\n",
    "srcVocab = output.getValue().getKey();\n",
    "Vocab tgtVocab = output.getValue().getValue();\n",
    "\n",
    "Batch batch = dataset.getData(manager).iterator().next();\n",
    "NDArray X = batch.getData().get(0);\n",
    "NDArray xValidLen = batch.getData().get(1);\n",
    "NDArray Y = batch.getLabels().get(0);\n",
    "NDArray yValidLen = batch.getLabels().get(1);\n",
    "System.out.println(X);\n",
    "System.out.println(xValidLen);\n",
    "System.out.println(Y);\n",
    "System.out.println(yValidLen);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 机器翻译指的是将文本序列从一种语言自动翻译成另一种语言。\n",
    "* 使用单词级词元化时的词表大小，将明显大于使用字符级词元化时的词表大小。为了缓解这一问题，我们可以将低频词元视为相同的未知词元。\n",
    "* 通过截断和填充文本序列，可以保证所有的文本序列都具有相同的长度，以便以小批量的方式加载。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 在`load_data_nmt`函数中尝试不同的`num_examples`参数值。这对源语言和目标语言的词表大小有何影响？\n",
    "1. 某些语言（例如中文和日语）的文本没有单词边界指示符（例如空格）。对于这种情况，单词级词元化仍然是个好主意吗？为什么？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
